import os
import sys
from math import log10, copysign
import pybedtools
import pysam
from collections import Counter, defaultdict


# Datasets to compare are taken from the command line; without arguments
# the script falls back to the HiSeq/CAGE pair and the unsuffixed peak file.
datasets = tuple(sys.argv[1:])
if datasets:
    filename = "peaks.%s.bed" % "_".join(datasets)
else:
    datasets = ("HiSeq", "CAGE")
    filename = "peaks.bed"


# Values of the XT tag that analyze_bamfile accepts; any XT value that is
# neither here nor in skip_targets trips an assertion there.
keep_targets = {
    'mRNA', 'lncRNA', 'gencode', 'fantomcat', 'genome',
    'MALAT1', 'TERC', 'RMRP', 'RPPH', 'snhg',
}

# Values of the XT tag whose reads analyze_bamfile silently drops.
skip_targets = {
    'chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'snoRNA', 'yRNA',
    'histone', 'scaRNA', 'snar', 'vRNA',
}

# Values of the XA tag that analyze_bamfile accepts (reads without an XA tag
# are labelled "other" instead).
keep_annotations = {
    'FANTOM5_enhancer',
    'roadmap_enhancer', 'roadmap_dyadic',
    'novel_enhancer_CAGE', 'novel_enhancer_HiSeq',
    'sense_proximal', 'sense_distal',
    'sense_upstream', 'sense_distal_upstream',
    'antisense', 'prompt',
    'antisense_distal', 'antisense_distal_upstream',
}

# Values of the XA tag whose reads analyze_bamfile silently drops.
skip_annotations = {'presnoRNA', 'prescaRNA', 'presnRNA', 'pretRNA'}


def find_libraries(dataset, directory="/osc-fs_home/mdehoon/Data/CASPARs/"):
    """Return the sorted library names available for a dataset.

    The libraries are the entries of ``<directory>/<dataset>/Mapping`` with
    their ``.bam`` extension removed.

    Parameters:
        dataset: name of the dataset subdirectory (for example "HiSeq").
        directory: root directory holding one subdirectory per dataset;
            defaults to the location used by the rest of this script.

    Returns:
        A sorted list of library names.

    Raises:
        AssertionError: if the Mapping directory contains an entry whose
            extension is not ".bam".
        OSError: if the Mapping directory cannot be listed.
    """
    subdirectory = os.path.join(directory, dataset, "Mapping")
    libraries = []
    for filename in os.listdir(subdirectory):
        library, extension = os.path.splitext(filename)
        assert extension == ".bam", \
            "unexpected file %s in %s" % (filename, subdirectory)
        if dataset == "HiSeq" and library == "t01_r3":
            # negative control library using water as input material
            continue
        libraries.append(library)
    return sorted(libraries)


def parse_lines(dataset, library):
    """Yield the alignment records of one library's BAM file.

    For the "HiSeq", "CAGE" and "StartSeq" datasets every record is yielded.
    For "MiSeq" the records are consumed in consecutive pairs and only the
    first record of each pair is yielded (presumably read 1 of a paired-end
    read -- confirm against the mapping pipeline).

    The BAM file is closed when the generator finishes, when it is closed
    early by the consumer, or when an error occurs.

    Parameters:
        dataset: dataset name; selects the directory and the pairing rule.
        library: library name, i.e. the BAM file name without ".bam".

    Raises:
        Exception: if ``dataset`` is not one of the known dataset names.
        RuntimeError: if a MiSeq file ends on the first record of a pair.
    """
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/%s/Mapping" % dataset
    path = os.path.join(directory, "%s.bam" % library)
    print("Reading", path)
    alignments = pysam.AlignmentFile(path)
    try:
        if dataset in ("HiSeq", "CAGE", "StartSeq"):
            yield from alignments
        elif dataset == "MiSeq":
            for line1 in alignments:
                # Consume the second record of the pair.  A bare next() here
                # would leak StopIteration out of the generator, which Python
                # reports as an uninformative RuntimeError.
                line2 = next(alignments, None)
                if line2 is None:
                    raise RuntimeError("Odd number of records in %s" % path)
                yield line1
        else:
            raise Exception("Unknown dataset %s" % dataset)
    finally:
        # Runs on normal exhaustion, on errors, and when the generator is
        # closed before being fully consumed.
        alignments.close()

def _make_tag_interval(position, name, count):
    """Return a one-base BED6 interval for a (chromosome, start, strand) tuple.

    The name column is ``name`` and the score column is ``str(count)``.
    """
    chromosome, start, strand = position
    fields = [chromosome,       # chromosome
              start,            # start
              start + 1,        # end
              name,             # name
              str(count),       # score
              strand]           # strand
    return pybedtools.create_interval_from_list(fields)


def analyze_bamfile(dataset, library):
    """Yield one-base intervals counting reads per start position.

    Reads are taken from ``parse_lines(dataset, library)`` and filtered:
    unmapped reads, reads whose XT tag is in ``skip_targets``, reads whose
    XA tag is in ``skip_annotations`` and reads whose NH tag is not 1 are
    dropped.  The start position of a forward-strand read is its first
    aligned reference base; for a reverse-strand read it is its last aligned
    reference base.

    Consecutive reads sharing (chromosome, start, strand) are merged into a
    single interval whose score is the number of such reads (as a float
    rendered with ``str``) and whose name is ``"<annotation>|<gene>"`` of the
    last read in the run.  The annotation is the XA tag, or "other" if the
    read has none; the gene is the XG tag, or "" if the read has none.

    Nothing is yielded if no read passes the filters.

    Raises:
        AssertionError: if an XT value is in neither ``skip_targets`` nor
            ``keep_targets``, or an XA value is in neither
            ``skip_annotations`` nor ``keep_annotations``.
        KeyError: if a mapped read lacks the XT or NH tag.
    """
    current = None
    name = None
    count = 0.0
    for line in parse_lines(dataset, library):
        if line.is_unmapped:
            continue
        target = line.get_tag("XT")
        if target in skip_targets:
            continue
        assert target in keep_targets
        try:
            annotation = line.get_tag("XA")
        except KeyError:
            annotation = "other"
        else:
            if annotation in skip_annotations:
                continue
            assert annotation in keep_annotations
        try:
            gene = line.get_tag("XG")
        except KeyError:
            gene = ""
        if line.get_tag("NH") != 1:
            # keep uniquely mapped reads only
            continue
        if line.is_reverse:
            strand = '-'
            start = line.reference_end - 1
        else:
            strand = '+'
            start = line.reference_start
        position = (line.reference_name, start, strand)
        if position != current:
            if current is not None:
                # A new position starts: emit the run collected so far.
                yield _make_tag_interval(current, name, count)
            current = position
            count = 0.0
        name = "%s|%s" % (annotation, gene)
        count += 1.0
    if current is not None:
        # Emit the final run; skipped entirely when no read was accepted.
        yield _make_tag_interval(current, name, count)

def read_deseq_results(datasets):
    """Read a DESeq result table and return signed -log10 p-values per peak.

    For ``datasets == ("HiSeq", "CAGE")`` the file "peaks.deseq.txt" is read;
    it must have a "peak" column followed by basemean/log2fc/pvalue columns
    for each of 00hr, 01hr, 04hr, 12hr, 24hr, 96hr and all, and the "all"
    columns are used.  For any other value the file
    "peaks.<datasets joined by '_'>.deseq.txt" is read, with the columns
    peak, basemean, log2fc and pvalue.

    Parameters:
        datasets: tuple of dataset names.

    Returns:
        A dict mapping each peak name to ``-log10(pvalue)`` carrying the sign
        of the log2 fold change.  A p-value of exactly 0 gives an infinite
        value instead of failing in ``log10``.

    Raises:
        AssertionError: if the header or the width of a row is not as
            expected.
        ValueError: if a log2fc or pvalue field is not a number, or the
            p-value is negative.
    """
    if datasets == ("HiSeq", "CAGE"):
        filename = "peaks.deseq.txt"
        expected = ["peak"]
        for prefix in ("00hr", "01hr", "04hr", "12hr", "24hr", "96hr", "all"):
            for column in ("basemean", "log2fc", "pvalue"):
                expected.append("%s_%s" % (prefix, column))
        index_log2fc = expected.index("all_log2fc")
        index_pvalue = expected.index("all_pvalue")
    else:
        filename = f"peaks.{'_'.join(datasets)}.deseq.txt"
        expected = ["peak", "basemean", "log2fc", "pvalue"]
        index_log2fc = expected.index("log2fc")
        index_pvalue = expected.index("pvalue")
    n = len(expected)
    print("Reading", filename)
    ppvalues = {}
    # The context manager closes the file even if a check below fails.
    with open(filename) as handle:
        header = next(handle).split()
        assert header == expected, "unexpected header in %s" % filename
        for line in handle:
            words = line.split()
            assert len(words) == n
            peak_name = words[0]
            log2fc = float(words[index_log2fc])
            pvalue = float(words[index_pvalue])
            if pvalue == 0.0:
                # log10(0) raises ValueError; the limit of -log10(p) is +inf.
                magnitude = float("inf")
            else:
                magnitude = -log10(pvalue)
            ppvalues[peak_name] = copysign(magnitude, log2fc)
    return ppvalues

# Signed -log10 p-value of every peak, keyed by peak name.
ppvalues = read_deseq_results(datasets)

print("Reading", filename)
peaks = pybedtools.BedTool(filename).sort()

# counts[peak name]["annotation|gene"] accumulates the tag counts of all
# libraries of all datasets that fall on the peak, on the same strand.
counts = defaultdict(Counter)

for dataset in datasets:
    for library in find_libraries(dataset):
        tags = pybedtools.BedTool(analyze_bamfile(dataset, library))
        tags = tags.saveas()
        tags = tags.sort()
        for hit in peaks.intersect(tags, wa=True, wb=True, s=True):
            # Each hit is the peak's six BED columns followed by the tag's.
            columns = hit.fields
            peak = pybedtools.create_interval_from_list(columns[:6])
            tag = pybedtools.create_interval_from_list(columns[6:])
            assert peak.strand == tag.strand
            counts[peak.name][tag.name] += float(tag.score)


# Annotation categories in decreasing order of priority; used to choose
# between annotations that are tied for the highest tag count in a peak.
preferred_annotations = ('sense_proximal',
                         'sense_upstream',
                         'prompt',
                         'antisense',
                         'FANTOM5_enhancer',
                         'roadmap_enhancer',
                         'roadmap_dyadic',
                         'novel_enhancer_CAGE',
                         'novel_enhancer_HiSeq',
                         'sense_distal',
                         'sense_distal_upstream',
                         'antisense_distal',
                         'antisense_distal_upstream',
                         'other')

# selected_annotations[peak name] = (annotation, gene) of the
# "annotation|gene" label with the highest tag count in that peak.  Ties are
# resolved by the order of preferred_annotations, then by the sorted order
# of the labels.
selected_annotations = {}
for name in counts:
    tag_counts = counts[name]
    maximum = max(tag_counts.values())
    annotations = sorted(label for label in tag_counts
                         if tag_counts[label] == maximum)
    # Split on the first "|" only, so that a gene name which itself contains
    # "|" does not break the two-way unpacking below.
    candidates = [label.split("|", 1) for label in annotations]
    for preferred_annotation in preferred_annotations:
        matches = [candidate for candidate in candidates
                   if candidate[0] == preferred_annotation]
        if matches:
            annotation, gene = matches[0]
            break
    else:
        raise Exception("Failed to find preferred annotation in %s" % annotations)
    selected_annotations[name] = (annotation, gene)
 
# Exactly two datasets are required here.  Peaks with a negative signed
# -log10 p-value are attributed to the alphabetically first dataset, all
# other peaks to the second one.
first_dataset, second_dataset = sorted(datasets)

print("Reading", filename)
peaks = pybedtools.BedTool(filename)
# Swap only the extension; str.replace would also rewrite any ".bed" that
# happens to occur inside a dataset name.
filename = os.path.splitext(filename)[0] + ".gff"
print("Writing", filename)
# The context manager flushes and closes the output even if a peak is
# missing from ppvalues or selected_annotations.
with open(filename, 'w') as handle:
    for peak in peaks:
        name = peak.name
        ppvalue = ppvalues[name]
        if ppvalue < 0:
            dataset = first_dataset
        else:
            dataset = second_dataset
        annotation, gene = selected_annotations[name]
        # GFF coordinates are 1-based inclusive, hence start+1.
        fields = [peak.chrom, dataset, annotation, peak.start+1, peak.end, peak.score, peak.strand, ".", ""]
        line = pybedtools.create_interval_from_list(fields)
        line.attrs['ppvalue'] = "%.4f" % ppvalue
        if gene:
            line.attrs['gene'] = gene
        handle.write(str(line))
